{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成网络模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-1-4c7193b29963>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data\\train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data\\train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting MNIST_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "# Read the MNIST dataset through the 'MNIST_data' directory, with one-hot labels.\n",
    "mnist = input_data.read_data_sets('MNIST_data', one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-2-574de5ddfecc>:15: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See @{tf.nn.softmax_cross_entropy_with_logits_v2}.\n",
      "\n",
      "Iter 0 Testing Accuracy: 0.8998\n",
      "Iter 1 Testing Accuracy: 0.9126\n",
      "Iter 2 Testing Accuracy: 0.9172\n",
      "Iter 3 Testing Accuracy: 0.9207\n",
      "Iter 4 Testing Accuracy: 0.9222\n",
      "Iter 5 Testing Accuracy: 0.9241\n",
      "Iter 6 Testing Accuracy: 0.9236\n",
      "Iter 7 Testing Accuracy: 0.9249\n",
      "Iter 8 Testing Accuracy: 0.9237\n",
      "Iter 9 Testing Accuracy: 0.9274\n",
      "Iter 10 Testing Accuracy: 0.9257\n",
      "Iter 11 Testing Accuracy: 0.925\n",
      "Iter 12 Testing Accuracy: 0.9262\n",
      "Iter 13 Testing Accuracy: 0.9264\n",
      "Iter 14 Testing Accuracy: 0.9262\n",
      "Iter 15 Testing Accuracy: 0.9272\n",
      "Iter 16 Testing Accuracy: 0.9279\n",
      "Iter 17 Testing Accuracy: 0.9271\n",
      "Iter 18 Testing Accuracy: 0.9257\n",
      "Iter 19 Testing Accuracy: 0.9269\n",
      "Iter 20 Testing Accuracy: 0.9278\n"
     ]
    }
   ],
   "source": [
    "# Mini-batch size and the number of whole batches in one pass over the training set.\n",
    "batch_size = 128\n",
    "n_batch = mnist.train.num_examples // batch_size\n",
    "\n",
    "# Inputs: rows of 784 image values and rows of 10 label values (one-hot, see the loading cell).\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "y = tf.placeholder(tf.float32, [None, 10])\n",
    "\n",
    "# A single linear layer. `prediction` holds raw logits; softmax is applied inside the loss.\n",
    "W = tf.Variable(tf.zeros([784, 10]))\n",
    "b = tf.Variable(tf.zeros([10]))\n",
    "prediction = tf.matmul(x, W) + b\n",
    "\n",
    "# Cost function: mean softmax cross-entropy.\n",
    "# The _v2 op replaces the deprecated softmax_cross_entropy_with_logits. The labels are\n",
    "# fed through a placeholder, so training of W and b should be unaffected by the switch.\n",
    "# Alternative (squared error): loss = tf.reduce_mean(tf.square(y - prediction))\n",
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))\n",
    "\n",
    "# Optimizer: Adam with its default settings.\n",
    "# Alternative (plain gradient descent): tf.train.GradientDescentOptimizer(0.2).minimize(loss)\n",
    "train_step = tf.train.AdamOptimizer().minimize(loss)\n",
    "\n",
    "# Op that initialises the variables; built after the optimizer has been added to the graph.\n",
    "init = tf.global_variables_initializer()\n",
    "\n",
    "# Boolean tensor: does the arg-max of the logits equal the arg-max of the label row?\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))\n",
    "\n",
    "# Accuracy is the mean of those booleans after casting them to float.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# Saver created without a var_list, after the model and optimizer have been defined.\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "# Checkpoint directory. exist_ok=True makes re-running the cell safe and avoids the\n",
    "# check-then-create race of a separate os.path.exists test.\n",
    "model_dir = 'net/8-1'\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    for epoch in range(21):\n",
    "        for batch in range(n_batch):\n",
    "            # Fetch the next mini-batch and take one optimisation step on it.\n",
    "            batch_xs, batch_ys = mnist.train.next_batch(batch_size)\n",
    "            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})\n",
    "        # Evaluate on the whole test set once per epoch.\n",
    "        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})\n",
    "        print(\"Iter \" + str(epoch) + \" Testing Accuracy: \" + str(acc))\n",
    "    # Write the checkpoint once, after the final epoch.\n",
    "    saver.save(sess, model_dir + '/test_net.ckpt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 载入训练好的模型\n",
    "* 运行完生成网络模型的代码后要重启一下kernel，再运行下面的代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-1-5e1a66207665>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data\\train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data\\train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting MNIST_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\Software\\Anaconda\\envs\\pyDL\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "# Load the dataset again; this part of the notebook is run after a kernel restart.\n",
    "mnist = input_data.read_data_sets(\n",
    "    'MNIST_data',\n",
    "    one_hot=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-2-e200fc4f53e5>:15: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See @{tf.nn.softmax_cross_entropy_with_logits_v2}.\n",
      "\n",
      "INFO:tensorflow:Restoring parameters from net/8-1/test_net.ckpt\n",
      " Init Accuracy0.098\n",
      " Restore Accuracy: 0.9278\n"
     ]
    }
   ],
   "source": [
    "# Inputs: rows of 784 image values and rows of 10 label values (one-hot, see the loading cell).\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "y = tf.placeholder(tf.float32, [None, 10])\n",
    "\n",
    "# The same single linear layer as in the training cell; `prediction` holds raw logits.\n",
    "W = tf.Variable(tf.zeros([784, 10]))\n",
    "b = tf.Variable(tf.zeros([10]))\n",
    "prediction = tf.matmul(x, W) + b\n",
    "\n",
    "# Loss and optimizer are never run in this cell. They are rebuilt so that this graph\n",
    "# defines the same set of variables as the training cell did when the checkpoint was saved.\n",
    "# The _v2 op replaces the deprecated softmax_cross_entropy_with_logits.\n",
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))\n",
    "train_step = tf.train.AdamOptimizer().minimize(loss)\n",
    "\n",
    "# Op that initialises the variables; used below for the pre-restore baseline.\n",
    "init = tf.global_variables_initializer()\n",
    "\n",
    "# Boolean tensor: does the arg-max of the logits equal the arg-max of the label row?\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))\n",
    "\n",
    "# Accuracy is the mean of those booleans after casting them to float.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# Saver created without a var_list, mirroring the training cell.\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "# Directory the training cell wrote its checkpoint to.\n",
    "model_dir = 'net/8-1'\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    # Baseline: test accuracy with W and b at their zero initial values.\n",
    "    acc1 = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})\n",
    "    # Load the saved values into the session, then measure accuracy again.\n",
    "    saver.restore(sess, model_dir + '/test_net.ckpt')\n",
    "    acc2 = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})\n",
    "    # Both labels now use the same ': ' separator (the first one used to lack it).\n",
    "    print(\" Init Accuracy: \" + str(acc1))\n",
    "    print(\" Restore Accuracy: \" + str(acc2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
